import numpy as np
import pandas as pd
import pylab as plt
import time  # 引入time模块

from matplotlib.ticker import MultipleLocator

# Start timestamp; the end of the script subtracts it to report total run time.
ticks1 = time.time()
# Show all columns when printing DataFrames
pd.set_option('display.max_columns', None)
# Show all rows when printing DataFrames
pd.set_option('display.max_rows', None)
# Set the display width of cell values to 100 characters (default is 50)
pd.set_option('max_colwidth',100)
# Figure number 1, 10x8 inches at 400 dpi; the scatter/grid plot further down draws onto it.
fig = plt.figure(1,(10,8),dpi = 400)
# NOTE(review): hard-coded absolute Windows path to the input CSV.
path = r'D:\workspace\data\csv\sepdata.csv'
df = pd.read_csv(path)
# Unfiltered coordinate columns (taken before the bounding-box filter); used for the scatter plot.
lon = df["longitude"]
lat = df["latitude"]
# Bounding box of the region of interest: lon 121.967-122.05, lat 30.2057-30.2664
# NOTE(review): y_high below is 30.2644, not 30.2664 as written in the range above -- confirm which is intended.
x_low = 121.967
x_high = 122.05
y_low = 30.2057
y_high = 30.2644
# Number of divisions across the box; each grid cell is extent/scales wide, rounded to 8 decimals.
scales = 50
x_size = round((x_high-x_low)/scales,8)
y_size = round((y_high-y_low)/scales,8)
print(x_size)
print(y_size)
# Restrict the trajectory data to a rectangular area of interest and compute the area's center.
def interest_area_range(df,x_low,x_high,y_low,y_high):
    """Return the rows of *df* inside a lon/lat bounding box, plus the box center.

    Both bounds of each axis are inclusive.

    Args:
        df: DataFrame with 'longitude' and 'latitude' columns.
        x_low, x_high: longitude bounds of the box.
        y_low, y_high: latitude bounds of the box.

    Returns:
        (filtered_df, center_lon, center_lat), where the center is the
        midpoint of the bounds (not of the data).
    """
    inside_box = (df['longitude'].between(x_low, x_high)
                  & df['latitude'].between(y_low, y_high))
    mid_lon = (x_low + x_high) / 2
    mid_lat = (y_low + y_high) / 2
    return df[inside_box], mid_lon, mid_lat
# Keep only points inside the bounding box (df is rebound to the filtered frame);
# the box center seeds the longitude/latitude grid-line arrays built below.
df,center_lon,center_lat = interest_area_range(df,x_low,x_high,y_low,y_high)

# Compute the longitude tick arrays.
# Inputs: upper longitude bound of the query area, center longitude of the
# query area, grid cell size along the longitude axis.
def lon_scale_func(x_high,center_lon,size):
    """Build longitude grid-line positions symmetric about *center_lon*.

    Starting at the center, one tick is added above and one below per step
    of *size*, until the upper tick reaches or passes *x_high*.

    Args:
        x_high: upper longitude bound the grid must cover.
        center_lon: longitude of the central grid line.
        size: grid cell width in degrees; must be positive.

    Returns:
        (grid_lon_scale, lon_scale_array): every grid-line longitude in
        ascending order, and the same values without the first (smallest).

    Raises:
        ValueError: if *size* is not a positive number (e.g. a cell size
            rounded down to 0), which would otherwise loop forever.
    """
    if not size > 0:
        # A zero/negative step never reaches x_high (infinite loop); a NaN
        # step would silently produce NaN ticks.
        raise ValueError(f"size must be positive, got {size!r}")
    lon_scale_list = [center_lon]
    lon_scale = center_lon
    i = 1
    while lon_scale < x_high:
        lon_scale += size  # next tick above the center
        # Mirror tick below the center (center - i*size), derived from the
        # accumulated upper tick. Because ticks come from repeated float
        # addition, the last upper tick may land just short of x_high and
        # trigger one extra step on each side.
        less_lon_scale = lon_scale - 2*i*size
        lon_scale_list.append(lon_scale)
        lon_scale_list.append(less_lon_scale)
        i += 1
    lon_scale_list.sort()
    grid_lon_scale = np.array(lon_scale_list)
    # Axis tick positions: all grid lines except the smallest one.
    del lon_scale_list[0]
    lon_scale_array = np.array(lon_scale_list)
    return grid_lon_scale, lon_scale_array
grid_lon_scale,lon_scale_array = lon_scale_func(x_high,center_lon,x_size)
# grid_lon_scale: every longitude grid line (ascending); lon_scale_array: the same without the smallest one

# Compute the latitude tick arrays.
# Inputs: upper latitude bound of the query area, center latitude of the
# query area, grid cell size along the latitude axis.
# Note: size may be the same as the longitude cell size or different.
def lat_scale_func(y_high,center_lat,size):
    """Build latitude grid-line positions symmetric about *center_lat*.

    Starting at the center, one tick is added above and one below per step
    of *size*, until the upper tick reaches or passes *y_high*.

    Args:
        y_high: upper latitude bound the grid must cover.
        center_lat: latitude of the central grid line.
        size: grid cell height in degrees; must be positive.

    Returns:
        (grid_lat_scale, lat_scale_array): every grid-line latitude in
        DESCENDING order (row 0 is the northernmost), and the same values
        without the first (largest).

    Raises:
        ValueError: if *size* is not a positive number (e.g. a cell size
            rounded down to 0), which would otherwise loop forever.
    """
    if not size > 0:
        # A zero/negative step never reaches y_high (infinite loop); a NaN
        # step would silently produce NaN ticks.
        raise ValueError(f"size must be positive, got {size!r}")
    lat_scale_list = [center_lat]
    lat_scale = center_lat
    i = 1
    while lat_scale < y_high:
        lat_scale += size  # next tick above the center
        # Mirror tick below the center (center - i*size), derived from the
        # accumulated upper tick. Because ticks come from repeated float
        # addition, the last upper tick may land just short of y_high and
        # trigger one extra step on each side.
        less_lat_scale = lat_scale - 2*i*size
        lat_scale_list.append(lat_scale)
        lat_scale_list.append(less_lat_scale)
        i += 1
    lat_scale_list.sort(reverse=True)
    grid_lat_scale = np.array(lat_scale_list)
    # Axis tick positions: all grid lines except the largest one.
    del lat_scale_list[0]
    lat_scale_array = np.array(lat_scale_list)
    return grid_lat_scale, lat_scale_array
# grid_lat_scale: every latitude grid line (descending); lat_scale_array: the same without the largest one
grid_lat_scale,lat_scale_array = lat_scale_func(y_high,center_lat,y_size)

# Number every grid cell, i.e. build dicts mapping cell (row, col) <-> cell index.
def index_for_grid(lon_scale_array,lat_scale_array):
    """Assign a sequential index to every grid cell in row-major order.

    Args:
        lon_scale_array: sequence whose length is the number of columns.
        lat_scale_array: sequence whose length is the number of rows.

    Returns:
        (grid_dict, dict_grid): grid_dict maps (row, col) -> index and
        dict_grid maps index -> (row, col). Indices run 0..rows*cols-1,
        row by row.
    """
    n_cols = len(lon_scale_array)
    cells = [(r, c) for r in range(len(lat_scale_array)) for c in range(n_cols)]
    grid_dict = {cell: idx for idx, cell in enumerate(cells)}
    dict_grid = dict(enumerate(cells))
    return grid_dict, dict_grid
# grid_dict: (row, col) -> cell index; dict_grid: cell index -> (row, col)
grid_dict,dict_grid=index_for_grid(lon_scale_array,lat_scale_array)
print(dict_grid)

# Compute the array of grid-cell center coordinates.
def grid_center_loc(lon_scale_array, lat_scale_array, grid_locarray):
    """Compute the center of every grid cell from its bounding coordinates.

    Args:
        lon_scale_array: sequence whose length is the number of columns.
        lat_scale_array: sequence whose length is the number of rows.
        grid_locarray: indexable as ``grid_locarray[row, col]``; each cell
            yields (low_lon, high_lon, low_lat, high_lat) at positions 0-3.

    Returns:
        (grid_center_loc_array, center_tuple_dict): a (rows, cols, 2) float
        array holding (center_lon, center_lat) per cell, and a dict mapping
        each (center_lon, center_lat) tuple to its (row, col). When two cells
        share a center, the later cell in row-major order wins.
    """
    n_rows = len(lat_scale_array)
    n_cols = len(lon_scale_array)
    centers = np.zeros((n_rows, n_cols, 2))
    center_tuple_dict = {}
    for row in range(n_rows):
        for col in range(n_cols):
            bounds = grid_locarray[row, col]
            mid_lon = (bounds[0] + bounds[1]) / 2
            mid_lat = (bounds[2] + bounds[3]) / 2
            centers[row, col] = (mid_lon, mid_lat)
            center_tuple_dict[(mid_lon, mid_lat)] = (row, col)
    return centers, center_tuple_dict
# Plot the points over the grid; ticks sit on grid lines so plt.grid draws the cell borders.
ax = fig.gca()
# lon_scale_array omits the smallest grid line and lat_scale_array the largest;
# those outer lines are instead the axis edges set by xlim/ylim below.
ax.set_xticks(lon_scale_array)
ax.set_yticks(lat_scale_array)
plt.xticks(rotation=15)
plt.yticks(rotation=15)
# Clip the view to the outermost grid lines (grid_lat_scale is in descending order).
plt.xlim(xmin=grid_lon_scale[0],xmax=grid_lon_scale[-1])
plt.ylim(ymin=grid_lat_scale[-1],ymax=grid_lat_scale[0])
# NOTE(review): lon/lat are the unfiltered CSV columns; points beyond the grid extent are clipped by the axis limits.
plt.scatter(lon,lat,s=0.1)
plt.grid(linewidth=1,color="black")
plt.show()

def getcenter_list(dict_grid,grid_lon_scale,grid_lat_scale):
    """Return the (lon, lat) center of every grid cell, in cell-index order.

    Args:
        dict_grid: maps cell index 0..n-1 to its (row, col). Columns index
            *grid_lon_scale* and rows index *grid_lat_scale*.
        grid_lon_scale: all longitude grid lines, ascending (cell col spans
            [col, col+1]).
        grid_lat_scale: all latitude grid lines, descending (cell row spans
            [row+1, row]).

    Returns:
        List of (center_lon, center_lat) tuples. Cell bounds and centers are
        each rounded to 8 decimal places.
    """
    centerlist = []
    for i in range(len(dict_grid)):
        grid_row = dict_grid[i][0]
        grid_colum = dict_grid[i][1]
        lon_low = round(grid_lon_scale[grid_colum], 8)
        lon_high = round(grid_lon_scale[grid_colum + 1], 8)
        # Latitudes are stored descending, so the next row's line is the lower edge.
        lat_low = round(grid_lat_scale[grid_row + 1], 8)
        lat_high = round(grid_lat_scale[grid_row], 8)
        centerlist.append((round((lon_low + lon_high) / 2, 8),
                           round((lat_low + lat_high) / 2, 8)))
    return centerlist
# Print the center of every grid cell, in cell-index order.
print(getcenter_list(dict_grid,grid_lon_scale,grid_lat_scale))

def getpointnums(dict_grid,grid_lon_scale,grid_lat_scale,data=None):
    """Count how many trajectory points fall inside each grid cell.

    Cell bounds are rounded to 8 decimal places and are inclusive on every
    side, so a point lying exactly on a shared grid line is counted in each
    adjacent cell.

    Args:
        dict_grid: maps cell index 0..n-1 to its (row, col). Columns index
            *grid_lon_scale* (ascending) and rows index *grid_lat_scale*
            (descending).
        grid_lon_scale: all longitude grid lines, ascending.
        grid_lat_scale: all latitude grid lines, descending.
        data: DataFrame with 'longitude' and 'latitude' columns. Defaults to
            the module-level ``df`` (the previous, implicit behavior).

    Returns:
        (pointnums_list, dict_grid_pointnums): the count per cell in
        cell-index order, and a dict mapping each key of *dict_grid* to its
        count.
    """
    if data is None:
        data = df  # module-level trajectory frame
    # Extract the coordinates once. The per-cell membership test below is then
    # a vectorized NumPy comparison instead of a Python loop over every row
    # for every cell (the original itertuples scan was prohibitively slow).
    lons = data['longitude'].to_numpy()
    lats = data['latitude'].to_numpy()
    pointnums_list = []
    for i in range(len(dict_grid)):
        grid_row = dict_grid[i][0]
        grid_colum = dict_grid[i][1]
        grid_lon_low = round(grid_lon_scale[grid_colum], 8)
        grid_lon_high = round(grid_lon_scale[grid_colum + 1], 8)
        # Latitudes are stored descending, so the next row's line is the lower edge.
        grid_lat_low = round(grid_lat_scale[grid_row + 1], 8)
        grid_lat_high = round(grid_lat_scale[grid_row], 8)
        in_cell = ((lons >= grid_lon_low) & (lons <= grid_lon_high)
                   & (lats >= grid_lat_low) & (lats <= grid_lat_high))
        pointnums_list.append(int(np.count_nonzero(in_cell)))
    dict_grid_pointnums = dict(zip(dict_grid, pointnums_list))
    return pointnums_list, dict_grid_pointnums
# Per-cell point counts (currently disabled).
# NOTE(review): getpointnums tests every point of df against every grid cell
# (cells x points comparisons) -- presumably why this call is commented out.
# pointnums_list,dict_grid_pointnums=getpointnums(dict_grid,grid_lon_scale,grid_lat_scale)

# print(pointnums_list)




# Print grid dimensions ("行" = rows, "列" = columns).
# NOTE(review): len(scale) - 2 is one less than the number of cells along each
# axis (cells = len(scale) - 1), i.e. the last zero-based index -- confirm
# whether the count or the last index was intended.
colums = len(grid_lon_scale)-1-1
rows = len(grid_lat_scale)-1-1
print("行:",rows)
print("列:",colums)
# Report total script run time in seconds, measured from ticks1 at the top.
ticks2 = time.time()
print("程序运行用时：",round(ticks2-ticks1,2),"s")